/*
    Copyright 2007-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/** \file

    Definitions of the mcrx_stage helper functions.  \ingroup mcrx */

#ifndef __mcrx_stage_impl__
#define __mcrx_stage_impl__

#include "mcrx-stage.h"
#include "grain_model_hack.h"
#include "mpi_util.h"

/** Reports whether this stage is recorded as completed in the output
    file. The flag is the boolean keyword "MCRX" + stage_ID() in the
    primary HDU; when that keyword does not exist the stage counts as
    not completed and false is returned. */
template <template <typename> class grid_type, typename stage_type> 
bool
mcrx::mcrx_stage<grid_type, stage_type>::check_stage_state ()
{
  bool completed = false;
  try {
    const std::string keyword_name = "MCRX" + stage().stage_ID();
    output_file_->pHDU().readKey(keyword_name, completed);
  }
  catch (CCfits::HDU::NoSuchKeyword&) {
    // keyword absent: the stage was never marked, keep the default
  }
  return completed;
}


/** Sets the completion flag of a stage in the output file. The flag is
    the boolean keyword "MCRX" + stage_id in the primary HDU. Setting it
    to false resets a stage to uncompleted status, which is needed
    whenever we revisit a run to add a new stage or update the numbers
    of rays.

    Only the MPI master task touches the file. If the keyword does not
    exist it is created only when state is true (a missing keyword
    already reads as "not completed"); if it exists it is rewritten only
    when its value differs from state.

    \param state     Desired completion state.
    \param stage_id  Stage identifier appended to "MCRX". */
template <template <typename> class grid_type, typename stage_type> 
void
mcrx::mcrx_stage<grid_type, stage_type>::
set_stage_state(bool state, const std::string& stage_id)
{
  if (is_mpi_master()) {
    const std::string key = "MCRX" + stage_id;
    bool stored = false;
    try {
      output_file_->pHDU().readKey(key, stored);
    }
    catch (CCfits::HDU::NoSuchKeyword&) {
      // no keyword yet: only add one when marking the stage complete
      if (state)
        output_file_->pHDU().addKey(key, state, "");
      return;
    }
    if (stored != state) {
      // keyword exists with the opposite value: overwrite it
      CCfits::Keyword& existing = output_file_->pHDU().keyWord(key);
      existing.setValue(state);
    }
  }
}



/** Runs one shooting stage from start to finish.

    In order, this function:
    - returns immediately if termination has already been signaled;
    - waits at an MPI barrier, then returns early if stage().skip()
      says the stage should be skipped;
    - opens the output file read-only and returns early if the stage is
      marked complete and at least the desired number of rays has been
      shot;
    - opens the input file and reopens the output file (read/write on
      the MPI master, read-only on the other tasks), then marks the
      postprocessing stages "_PP1" and "_PP2" as not completed;
    - reads the units and the random number generator states;
    - restores the stage from a per-rank dump file if one exists (the
      dump file is deleted after a successful read), otherwise loads it
      from the output file if it was completed before, otherwise sets up
      new objects;
    - shoots, and if shooting ran to completion saves the results, the
      rng states and the ray count to the output file;
    - if termination was signaled, writes a dump file (note the
      assert(0) guarding that path);
    - records the stage completion state and releases the files and
      objects. */
template <template <typename> class grid_type, typename stage_type> 
void 
mcrx::mcrx_stage<grid_type, stage_type>::run_stage()
{
  using namespace std;

  if (m_.terminator())
    return;

  // we need a barrier here because otherwise one task might start
  // reading the output file before the master task has finished
  // writing it in the previous stage.
  mpi_barrier();

  cout << "\n  *** ENTERING " << stage ().stage_name ()
       << " SHOOTING STAGE ***\n" << endl;

  // See how many rays are requested.
  n_rays_desired_ = stage ().get_rays_desired();
  if(stage().skip()) {
    cout << "Skipping " << stage ().stage_name()
	 << ", 0 rays requested" << endl;
    return;
  }

  // open output file (read-only)
  const string output_file_name = 
    word_expand (p_.getValue("output_file", string ())) [0];
  output_file_.reset(new CCfits::FITS (output_file_name, CCfits::Read));

  // check if stage previously completed and how many rays have been shot 
  bool complete = check_stage_state ();
  n_rays_ = get_rays_completed ();

  if (complete && (n_rays_ >= n_rays_desired_)) {
    cout << "  Stage already completed" << endl;
    // NOTE(review): output_file_ stays open read-only here; it is only
    // replaced by the reset() at the top of the next run_stage call.
    return;
  }
  else {
    cout << n_rays_ << " rays previously shot, want "
	 << n_rays_desired_ << endl;
  }

  // open input file, and reopen output file read/write
  const string input_file_name = 
    word_expand (p_.getValue("input_file", string ())) [0];
  input_file_.reset (new CCfits::FITS(input_file_name, CCfits::Read));
  // close the read-only handle before reopening the same file
  output_file_->destroy();
  output_file_.reset();
  // only the master task gets write access; the others read-only
  if(is_mpi_master()) {
    output_file_.reset(new CCfits::FITS (output_file_name, CCfits::Write));
  }
  else
    output_file_.reset(new CCfits::FITS (output_file_name, CCfits::Read));

  // set the postprocessing stages to not completed so that if we
  // reenter after completing them they are redone
  set_stage_state(false, "_PP1");
  set_stage_state(false, "_PP2");
  
  read_units (m_.units);

  // we always load the states first, then if there is a dump file
  // we overwrite
  m_.load_rng_states(*output_file_);

  // check for dump file (one per MPI rank, named after the output
  // file and the stage ID)
  // NOTE(review): this recomputes the same path as output_file_name.
  const string dump_file_name = 
    word_expand (p_.getValue("output_file", string ()))[0] +
    "."  + stage ().stage_ID() + ".mcrxdump" + mpi_rank_string();
  binifstream dump_file (dump_file_name);
  const bool dump = dump_file.good();

  // Set up objects. There are 3 possibilities. If dump is set, we
  // load from dump. If complete is set, we load from the output
  // file. Otherwise we set up for a new run.
  if (dump) {
    // dump layout: ray count, stage data, rng states
    dump_file >> n_rays_;
    cout << "Loading previously saved data from dump file: "
	 << dump_file_name << " containing " << n_rays_ << " rays" << endl;
    stage ().load_dump(dump_file );
    m_.load_rng_states(dump_file);

    // check that the file length is what was expected: every read
    // above must have succeeded, and one more get() must hit EOF
    // (i.e. there is no trailing data)
    char c;
    if (!(dump_file.good() && dump_file.get(c ).eof())) {
      cerr << "FATAL: error reading dump file!"  << endl;
      m_.kill();
      // NOTE(review): this return leaves input_file_/output_file_ open
      // and does not update the stage completion state.
      return;
    }
    // the dump has been consumed; remove it so it is not reused
    unlink (dump_file_name.c_str());
  }
  else if (complete) {
    stage ().load_file();
  }
  else {
    stage ().setup_objects ();
  }

  dump_file.close();

  if (!m_.terminator()) {
    // if we did setup successfully, continue with shooting. otherwise
    // fall through to dumping
    cout  << "Shooting: ";
    if (!stage ().shoot()) {
      // shoot returned false, which means we ran until completion.
      cout << "Shooting complete.  Saving."  << endl;
      stage ().save_file();
      m_.save_rng_states(*output_file_ );
      update_rays_completed(n_rays_);
    }
  }
  
  if (m_.terminator()) {
    // We got the termination signal, either during shoting or during
    // setup.  Dump and exit.

    /// \todo how do we deal with this for mpi?
    // With assertions enabled this aborts; the dump below is only
    // written in NDEBUG builds.
    assert(0);
    
    // dump layout mirrors the loading code above: ray count, stage
    // data, rng states
    cout << "Shooting terminated.  Dumping data to file " 
	 << dump_file_name << endl;
    binofstream dump_outfile (dump_file_name);
    dump_outfile << n_rays_;
    stage ().save_dump(dump_outfile );
    m_.write_rng_states (dump_outfile);
  }
  
  // write whether we completed or not.
  stage ().set_stage_state(!m_.terminator(), stage().stage_ID());

  // destroy objects
  output_file_.reset();
  input_file_.reset();
  g_.reset();
  eme_.reset();
  emi_.reset();
  model_.reset();

  if (m_.terminator())
    cout << "Stage " << stage ().stage_name()
	 << " terminated, restart to complete."  << endl;
  else
    cout << "Stage " << stage ().stage_name()
	 << " complete." << endl;
};


/** Fills the unit map from the input file.

    The length, mass, time, luminosity and L_lambda units are taken from
    the units of the "position", "mass", "age", "L_bol" and "L_lambda"
    columns of the PARTICLEDATA HDU. The wavelength unit comes from the
    "lambda" column of the LAMBDA HDU. The temperature unit is read from
    the "tempunit" keyword of the GRIDDATA HDU, or set to "K" if the file
    has no GRIDDATA HDU.

    \param units  Unit map that receives the entries listed above. */
template <template <typename> class grid_type, typename stage_type> 
void 
mcrx::mcrx_stage<grid_type, stage_type>::read_units(T_unit_map& units)
{
  CCfits::ExtHDU& particles = open_HDU(*input_file_, "PARTICLEDATA");
  CCfits::ExtHDU& lambdas = open_HDU(*input_file_, "LAMBDA");

  units["length"] = particles.column("position").unit();
  units["mass"] = particles.column("mass").unit();
  units["time"] = particles.column("age").unit();
  units["luminosity"] = particles.column("L_bol").unit();
  units["L_lambda"] = particles.column("L_lambda").unit();
  units["wavelength"] = lambdas.column("lambda").unit();

  try {
    CCfits::ExtHDU& grid = open_HDU(*input_file_, "GRIDDATA");
    grid.readKey("tempunit", units["temperature"]);
  }
  catch (CCfits::FITS::NoSuchHDU&) {
    // no grid. in that case we just decide we will use K as temp unit
    units["temperature"] = "K";
  }

  // the SED unit must be luminosity per wavelength
  assert(units.get("L_lambda") ==
         units.get("luminosity") + "/" + units.get("wavelength"));
}

/** Performs the shooting for all the nonscatter stages, including the
    aux ones. They are all the same in terms of operations, so a single
    implementation serves them all. This function is trivial since all
    the multithreading functionality has been wrapped into the general
    shoot function.

    The nonscatter shooting does not use the stage's grid or dust model,
    so default-constructed dummy_grid and T_dust_model objects are
    passed to the xfer constructor.

    Exits the program with status 1 if n_threads does not match the
    number of random number generator states that were loaded.

    \return The value of shoot(); false means the requested number of
    rays was completed. */
template <template <typename> class grid_type, typename stage_type> 
bool
mcrx::mcrx_stage<grid_type, stage_type>::shoot_nonscatter ()
{
  const int n_threads = p_.getValue("n_threads", int(1), true);
  const bool bind_threads = p_.getValue("bind_threads", bool(false), true);
  const int bank_size = p_.getValue("thread_bank_size", int(n_threads), true);

  // each thread needs its own rng state, loaded from the output file
  if (n_threads != m_.rng_states.size()) {
    std::cerr << "Number of threads requested (" << n_threads
              << ") does not equal the number of initialized random states ("
              << m_.rng_states.size()
              << ").\nDid you change n_threads without removing the output file?"
              << std::endl;
    exit(1);
  }

  // Create dummy objects that are needed for the xfer constructor
  dummy_grid g;
  T_dust_model model;
  // The xfer rngs are loaded in the thread shooting function.
  mcrx_rng_policy rng;

  // For this purpose, we can use the default xfer values 
  xfer<T_dust_model, dummy_grid> x 
    (g, *emi_, *eme_, model, rng, stage().biaser());
  T_shooter s(stage().shooter());

  return shoot (x, s, mcrx_terminator (m_), m_.rng_states, 
		n_rays_desired_, n_rays_, n_threads, bind_threads, bank_size,
		11, stage().stage_name());
}


/** Performs the shooting for stages that use the real grid (*g_) and
    dust model (*model_), with the multithreading handled by the general
    shoot function.

    The preferences "i_min" and "i_max" are read without a default flag,
    so the user must be explicit about them (per the comment below).
    "n_scatter_min" and "n_scatter_max" default to 0 and
    blitz::huge(int()).

    Exits the program with status 1 if n_threads does not match the
    number of random number generator states that were loaded.

    \param add_intensity  Forwarded to the xfer constructor.
    \return The value of shoot(); false means the requested number of
    rays was completed. */
template <template <typename> class grid_type, typename stage_type> 
bool
mcrx::mcrx_stage<grid_type, stage_type>::shoot_scatter (bool add_intensity)
{
  const int n_threads = p_.getValue("n_threads", int(1), true);
  const bool bind_threads = p_.getValue("bind_threads", bool(false), true);
  const int bank_size = p_.getValue("thread_bank_size", int(n_threads), true);

  // each thread needs its own rng state, loaded from the output file
  if (n_threads != m_.rng_states.size()) {
    std::cerr << "Number of threads requested (" << n_threads
              << ") does not equal the number of initialized random states ("
              << m_.rng_states.size()
              << ").\nDid you change n_threads without removing the output file?"
              << std::endl;
    exit(1);
  }

  // We force the user to be explicit about the radiative transfer
  // settings used
  const T_float i_min = p_.getValue("i_min", T_float ());
  const T_float i_max = p_.getValue("i_max", T_float ());
  const int minscat = p_.getValue("n_scatter_min", int(0), true);
  const int maxscat = p_.getValue("n_scatter_max", blitz::huge(int()), true);

  // The xfer rngs are loaded in the thread shooting function.
  mcrx_rng_policy rng;

  xfer<T_dust_model, T_grid> x (*g_, *emi_, *eme_, *model_,
				rng, stage().biaser(),
				i_min, i_max, 
				add_intensity, false, minscat, maxscat);

  T_shooter s(stage().shooter());
  return shoot (x, s, mcrx_terminator (m_), m_.rng_states, 
		n_rays_desired_, n_rays_, n_threads, bind_threads, bank_size,
		10, stage().stage_name());
}

/** Generates images by direct integration using an integration_shooter
    in intensity mode, with "rays_per_cell" integrations per cell
    (default 1).

    No Monte Carlo rays are requested (shoot is called with 0 desired
    rays and a throwaway counter), so n_rays_ is not changed. Afterwards
    every image is multiplied by its pixel's normalized area so that it
    is in the same units as the shooting output.

    Exits the program with status 1 if n_threads does not match the
    number of random number generator states that were loaded.

    \return The terminator state after integrating. */
template <template <typename> class grid_type, typename stage_type> 
bool
mcrx::mcrx_stage<grid_type, stage_type>::shoot_integration ()
{
  const int n_threads = p_.getValue("n_threads", int(1), true);
  const bool bind_threads = p_.getValue("bind_threads", false, true);
  const int bank_size = p_.getValue("thread_bank_size", int(n_threads), true);
  const int rays_per_cell = p_.getValue("rays_per_cell", int(1), true);

  // each thread needs its own rng state, loaded from the output file
  if (n_threads != m_.rng_states.size()) {
    std::cerr << "Number of threads requested (" << n_threads
              << ") does not equal the number of initialized random states ("
              << m_.rng_states.size()
              << ").\nDid you change n_threads without removing the output file?"
              << std::endl;
    exit(1);
  }

  // The xfer rngs are loaded in the thread shooting function.
  mcrx_rng_policy rng;

  xfer<T_dust_model, T_grid> x (*g_, *emi_, *eme_, *model_,
				rng, stage().biaser(),
				1,1, false, false);
  x.set_integrations_per_cell(rays_per_cell);

  integration_shooter s(integration_shooter::intensity);
  long dummy=0;
  shoot (x, s, mcrx_terminator (m_), m_.rng_states, 
	 0, dummy, n_threads, bind_threads, bank_size,
	 10, stage().stage_name());

  // because the image values when integrating directly are already in
  // surface brightness, we need to back convert them to what's output
  // by the shooting so that they are consistent. To do this we must
  // multiply them by the normalized area of a pixel.
  for(typename T_emergence::iterator c= eme_->begin();
      c != eme_->end(); ++c)
    (*c)->get_image() *= (*c)->pixel_normalized_area();
  
  return mcrx_terminator(m_)();
}


/** Generates images by direct integration using an integration_shooter
    in emission mode, with "rays_per_cell" integrations per cell
    (default 1).

    No Monte Carlo rays are requested (shoot is called with 0 desired
    rays and a throwaway counter), so n_rays_ is not changed. Afterwards
    every image is multiplied by its pixel's normalized area divided by
    4*pi so that it is in the same units as the shooting output.

    Exits the program with status 1 if n_threads does not match the
    number of random number generator states that were loaded.

    \return The terminator state after integrating. */
template <template <typename> class grid_type, typename stage_type> 
bool
mcrx::mcrx_stage<grid_type, stage_type>::shoot_emission_integration ()
{
  const int n_threads = p_.getValue("n_threads", int(1), true);
  const bool bind_threads = p_.getValue("bind_threads", false, true);
  const int bank_size = p_.getValue("thread_bank_size", int(n_threads), true);
  const int rays_per_cell = p_.getValue("rays_per_cell", int(1), true);

  // each thread needs its own rng state, loaded from the output file
  if (n_threads != m_.rng_states.size()) {
    std::cerr << "Number of threads requested (" << n_threads
              << ") does not equal the number of initialized random states ("
              << m_.rng_states.size()
              << ").\nDid you change n_threads without removing the output file?"
              << std::endl;
    exit(1);
  }

  // The xfer rngs are loaded in the thread shooting function.
  mcrx_rng_policy rng;

  xfer<T_dust_model, T_grid> x (*g_, *emi_, *eme_, *model_,
				rng, stage().biaser(),
				1,1, false, false);
  x.set_integrations_per_cell(rays_per_cell);

  integration_shooter s(integration_shooter::emission);
  long dummy=0;
  shoot (x, s, mcrx_terminator (m_), m_.rng_states, 
	 0, dummy, n_threads, bind_threads, bank_size,
	 10, stage().stage_name());

  // because the image values when integrating directly are already in
  // surface density, we need to back convert them to what's output
  // by the shooting so that they are consistent. To do this we must
  // multiply them by the normalized area of a pixel/4pi.
  for(typename T_emergence::iterator c= eme_->begin();
      c != eme_->end(); ++c)
    (*c)->get_image() *= blitz::scalar((*c)->pixel_normalized_area()/(4*constants::pi));
  
  return mcrx_terminator(m_)();
}


/** Returns the number of rays previously shot for this stage, as
    recorded in the keyword "NRAY" + stage_ID() of the
    SCATTERING_LAMBDAS HDU of the output file. Returns 0 if the keyword
    does not exist (the stage has never recorded a ray count).

    The value is read into a long, matching both the return type and the
    long that update_rays_completed writes. Reading it into an int would
    overflow for counts above INT_MAX, and the resulting error used to
    be swallowed and reported as 0 rays. */
template <template <typename> class grid_type, typename stage_type> 
long 
mcrx::mcrx_stage<grid_type, stage_type>::get_rays_completed () 
{
  long n_rays = 0;
  CCfits::ExtHDU& hdu = open_HDU(*output_file_,"SCATTERING_LAMBDAS");
  try {
    hdu.readKey("NRAY" + stage ().stage_ID(), n_rays);
  }
  catch (CCfits::HDU::NoSuchKeyword&) {
    // no count recorded yet: nothing has been shot for this stage
  }
  return n_rays;
}


/** Records n as the number of rays completed for this stage in the
    keyword "NRAY" + stage_ID() of the SCATTERING_LAMBDAS HDU of the
    output file. Only the MPI master task writes; in run_stage it is
    the only task that opens the output file with write access.

    \param n  Number of rays completed. */
template <template <typename> class grid_type, typename stage_type> 
void 
mcrx::mcrx_stage<grid_type, stage_type>::
update_rays_completed (long n)
{
  if (!is_mpi_master())
    return;

  std::cout << "Updating number of rays completed: " << n << std::endl;
  CCfits::ExtHDU& hdu = open_HDU(*output_file_, "SCATTERING_LAMBDAS");
  // the keyword comment previously lacked the space before "stage"
  hdu.addKey("NRAY" + stage ().stage_ID(), n, 
             "Number of rays shot for " + stage().stage_name() + " stage");
}


/** Helper function that parses the preferences keywords to load a
    scatterer or grain_model.

    If the "grain_model" keyword is defined, the grain model is loaded
    with load_grain_model. Otherwise the file named by the
    "dust_grain_file" keyword is loaded, first as a Draine_grain and, if
    that throws, as an HG_dust_grain. If both attempts fail, an error is
    printed to stderr and the exception from the HG_dust_grain attempt
    is rethrown.

    \param p      Preferences holding the dust keywords.
    \param units  Unit map passed to load_grain_model.
    \return A vector holding the single loaded scatterer. */
template<typename T_scatterer_vector>
T_scatterer_vector
mcrx::read_dust_grains(Preferences& p, const T_unit_map& units)
{
  typedef typename T_scatterer_vector::value_type T_grain_ptr;
  T_scatterer_vector v;

  // currently, we only have a simple dust model so we only have to
  // worry about what type of dust grain we have

  // check if we are using an explicit grain model or just an opacity file
  if (p.defined ("grain_model")) {
    v.push_back(load_grain_model<polychromatic_scatterer_policy, mcrx_rng_policy>(p, units));
  }
  else {
    // expand the file name once instead of re-parsing it for every attempt
    const std::string grain_file =
      word_expand (p.getValue("dust_grain_file", std::string ())) [0];
    std::cout << "Reading dust data for " << grain_file << '\n';

    // load a dust file, trying the Draine format first
    try {
      v.push_back(T_grain_ptr(new Draine_grain<polychromatic_scatterer_policy, mcrx_rng_policy> (grain_file)));
    }
    catch (...) {
      try {
        v.push_back(T_grain_ptr(new HG_dust_grain<polychromatic_scatterer_policy, mcrx_rng_policy> (grain_file)));
      }
      catch (...) {
        // we could not find a type of scatterer that worked
        std::cerr << "Fatal: Could not load a dust grain" << std::endl;
        throw;
      }
    }
  }

  return v;
}

#endif
